import redis
import pickle
import numpy as np
from . import serialization_tool

'''
    Redis存储结构:
    Hash:
    ai-food-shop:shopId
        repo: 每个dishId_feaId对应的特征组向量 repo[dishId_feaId] = [array([xx,xx,dtype=float32])]
        all_repo_feats: 特征拼接向量(每行已L2归一化)
        all_repo_ids: 拼接dishId_feaId, 与all_repo_feats逐行对应
        repo_count: 特征库中的条目(dishId_feaId)总数
    
    String:
    ai-food-device:deviceId
        shopId
'''

class FooderRedisError(RuntimeError):
    '''Error type raised by the Redis helpers in this module.

    :param errMsg: human-readable message. It is also passed to RuntimeError,
        so it appears in ``str(err)`` and ``err.args``.
    :param errDetails: optional extra context. It can also be attached later
        via ``data()``.
    '''

    def __init__(self, errMsg, errDetails=None):
        super().__init__(errMsg)
        self.errMsg = errMsg
        # Always defined, so handlers can read it without a hasattr() check.
        self.errDetails = errDetails

    def data(self, errDetails):
        '''Attach ``errDetails`` to this instance and return the instance.

        The return value allows ``raise ErrX.data(...)``. This mutates the
        instance in place. On the shared module-level instances below, the
        details therefore persist until they are overwritten.
        '''
        self.errDetails = errDetails
        return self


# Shared, module-level error instances. Because they are singletons, state set
# on them carries over from one raise to the next. This covers ``errDetails``
# set via ``data()`` and ``__traceback__`` from an earlier raise. Raise them as
# ``raise ErrX.with_traceback(None)`` to avoid chaining onto old tracebacks.
ErrRedisParamNone = FooderRedisError("param has None")
ErrRedisConnPool = FooderRedisError("conn_pool is null and refresh fail")
ErrRedisRedisNone = FooderRedisError("redis.Redis is None")
ErrRedishNameNotExist = FooderRedisError("hname is not exist")

class RedisUtil:
    '''Helper around a redis-py connection pool for the per-shop dish feature store.

    Keys touched by this class:
      * ``ai-food-shop:{shopId}`` -- hash with fields ``repo``, ``all_repo_feats``,
        ``all_repo_ids`` and ``repo_count`` (written by ``fea_regist``).
      * ``ai-food-device:{deviceId}`` -- string holding the device's shopId.

    Errors are the module's shared ``FooderRedisError`` instances. They are
    raised via ``with_traceback(None)`` so that a reused instance does not keep
    chaining, and keeping alive, the frames of earlier raises.
    '''

    def __init__(self, host, password, port):
        '''Store the connection settings and build the connection pool.

        :raises FooderRedisError: ``ErrRedisParamNone`` if host, password or
            port is None.
        '''
        if host is None or password is None or port is None:
            raise ErrRedisParamNone.with_traceback(None)
        self.host = host
        self.password = password
        self.port = port
        self.refresh_init()

    def refresh_init(self):
        '''(Re)build the connection pool on database 1, with responses decoded to str.'''
        self.conn_pool = redis.ConnectionPool(host=self.host,
                                              password=self.password,
                                              port=self.port,
                                              db=1,
                                              decode_responses=True)

    def _get_redis(self):
        '''Return a ``redis.Redis`` client on the pool, rebuilding the pool if it is None.

        :raises FooderRedisError: ``ErrRedisConnPool`` if the pool is still None
            after a rebuild; ``ErrRedisRedisNone`` if no client was created.
        '''
        if self.conn_pool is None:
            self.refresh_init()
        if self.conn_pool is None:
            raise ErrRedisConnPool.with_traceback(None)
        r = redis.Redis(connection_pool=self.conn_pool)
        if r is None:
            raise ErrRedisRedisNone.with_traceback(None)
        return r

    def hash_get_all_keys(self, name='商家A'):
        '''Debug helper: print key statistics for the current database and one hash.

        Prints, in order:
          * the database size;
          * every key name;
          * the field names of hash ``name``;
          * the unpickled ``all_repo_ids`` field of that hash, or None if the
            field is missing.

        :param name: hash to inspect; defaults to the previously hard-coded '商家A'.
        '''
        r = self._get_redis()
        print('获取当前数据库中key的数目=', r.dbsize())
        # NOTE: KEYS walks the whole keyspace and blocks the server; debug use only.
        print('获取所有符合规则的key=', r.keys())
        print(f'获取name={name}的所有key=', r.hkeys(name))
        raw_ids = r.hget(name, 'all_repo_ids')
        # hget returns None for a missing field; guard instead of crashing on None.encode.
        # NOTE(review): pickle.loads executes code embedded in the payload; this is
        # only safe if nobody untrusted can write to this Redis instance.
        repo_ids = None if raw_ids is None else pickle.loads(raw_ids.encode('latin1'))
        print(f'获取name={name}的all_repo_ids键的值=', repo_ids)

    def get_shopId_device(self, deviceId):
        '''Return the shopId stored for ``deviceId``, or None if it is not registered.

        :raises FooderRedisError: ``ErrRedisParamNone`` if deviceId is None.
        '''
        if deviceId is None:
            raise ErrRedisParamNone.with_traceback(None)
        r = self._get_redis()
        # GET already returns None for a missing key, so the separate EXISTS
        # round-trip (which could race with a concurrent delete) is not needed.
        return r.get(f'ai-food-device:{deviceId}')

    def fea_regist(self, shopId, dishId_feaId, fea):
        '''Register one feature under ``dishId_feaId`` in the shop's repo.

        A feature already stored under the same ``dishId_feaId`` is replaced.
        The derived fields are then rebuilt from the whole repo and written back
        together with ``repo``:
          * ``all_repo_feats`` -- L2-normalised rows;
          * ``all_repo_ids`` -- one id per row;
          * ``repo_count`` -- number of repo entries.

        :param shopId: shop whose repo is updated.
        :param dishId_feaId: repo key for the feature.
        :param fea: feature vector (presumably a 1-D float32 numpy array -- TODO confirm).
        :raises FooderRedisError: ``ErrRedisParamNone`` if any argument is None.
        '''
        if dishId_feaId is None or fea is None or shopId is None:
            raise ErrRedisParamNone.with_traceback(None)
        r = self._get_redis()

        hShopId = f'ai-food-shop:{shopId}'
        # 1. Load the existing repo (hget returns None when the field is absent)
        #    and register/overwrite this entry.
        raw_repo = r.hget(hShopId, 'repo')
        repo = {} if raw_repo is None else serialization_tool.strSerializTodict(raw_repo)
        repo[dishId_feaId] = [fea]
        repo_count = len(repo)
        # 2. Rebuild the flattened, L2-normalised feature matrix and the id list
        #    that labels each of its rows. A zero vector normalises to NaN.
        repo_concat = np.concatenate([repo[key] for key in repo], 0)
        all_repo_feats = np.array([x / np.linalg.norm(x) for x in repo_concat])
        all_repo_ids = [key for key, key_feats in repo.items() for _ in range(len(key_feats))]
        # 3. Serialize everything first, so a serialization failure cannot leave
        #    a partially written hash. The four HSETs below are still separate
        #    commands and are not atomic with respect to other clients.
        repo_str = serialization_tool.dictSerializToStr(repo)
        feats_str = serialization_tool.dictSerializToStr(all_repo_feats)
        ids_str = serialization_tool.dictSerializToStr(all_repo_ids)
        r.hset(hShopId, 'repo', repo_str)
        r.hset(hShopId, 'all_repo_feats', feats_str)
        r.hset(hShopId, 'all_repo_ids', ids_str)
        r.hset(hShopId, 'repo_count', repo_count)

    def feas_search(self, shopId, feats):
        '''Rank a shop's registered entries by cosine similarity to each query feature.

        :param shopId: shop whose feature repo is searched.
        :param feats: sequence of query feature vectors, each L2-normalised here
            (presumably shape (n_query, dim) -- TODO confirm with callers).
        :return: ``(ids_results, scores_result, repo_count)``.
            For each query in order, every distinct repo id is appended
            best-first, together with its highest cosine similarity x 100
            rounded to 2 decimals. All queries share the same two flat lists.
            ``repo_count`` is the stored repo size.
        :raises FooderRedisError: ``ErrRedisParamNone`` if an argument is None;
            ``ErrRedishNameNotExist`` if a required hash field is missing.
        '''
        if feats is None or shopId is None:
            raise ErrRedisParamNone.with_traceback(None)
        r = self._get_redis()

        hShopId = f'ai-food-shop:{shopId}'
        # EXISTS treats every argument as a key name, so exists(hShopId,
        # 'all_repo_feats') only proved that the hash itself existed.
        # HEXISTS checks the individual field.
        for field in ('all_repo_feats', 'all_repo_ids', 'repo_count'):
            if not r.hexists(hShopId, field):
                raise ErrRedishNameNotExist.with_traceback(None)

        # Load the stored feature matrix and its row labels.
        all_repo_feats = serialization_tool.strSerializTodict(r.hget(hShopId, 'all_repo_feats'))
        all_repo_ids = np.array(serialization_tool.strSerializTodict(r.hget(hShopId, 'all_repo_ids')))
        repo_count = int(r.hget(hShopId, 'repo_count'))
        # fea_regist stores normalised rows, so once the queries are normalised
        # the matrix product is the cosine similarity.
        feats = [x / np.linalg.norm(x) for x in feats]
        cos_sim = np.matmul(feats, all_repo_feats.transpose((1, 0)))
        # Row-wise best-first ordering. Scores are read through the same
        # indices instead of sorting a second time.
        order = np.argsort(cos_sim)[..., ::-1]
        ranked_scores = np.take_along_axis(cos_sim, order, axis=-1)

        ids_results, scores_result = [], []
        for row_order, row_scores in zip(order, ranked_scores):
            seen = set()  # keep only the best-scoring occurrence of each id
            for repo_id, score in zip(all_repo_ids[row_order], row_scores):
                if repo_id in seen:
                    continue
                seen.add(repo_id)
                ids_results.append(repo_id)
                scores_result.append(round(float(100 * score), 2))
        return ids_results, scores_result, repo_count

    def hash_get_data(self, name, key):
        '''Return field ``key`` of hash ``name``.

        :raises FooderRedisError: ``ErrRedisParamNone`` if name or key is None;
            ``ErrRedishNameNotExist`` if the field does not exist.
        '''
        if name is None or key is None:
            raise ErrRedisParamNone.with_traceback(None)
        r = self._get_redis()
        # A single HGET replaces HEXISTS + HGET: None means the field is absent.
        value = r.hget(name, key)
        if value is None:
            raise ErrRedishNameNotExist.with_traceback(None)
        return value

    def hash_set_data(self, name, key, value):
        '''Set field ``key`` of hash ``name`` to ``value``.

        :return: the result of ``redis.Redis.hset``.
        :raises FooderRedisError: ``ErrRedisParamNone`` if any argument is None.
        '''
        if name is None or key is None or value is None:
            # Raise, consistent with every other method; this used to return the error.
            raise ErrRedisParamNone.with_traceback(None)
        return self._get_redis().hset(name, key, value)
